# Copyright 2025 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This module contains building blocks for cropping."""

import random

import numpy as np
import torch
from biotite.structure import Atom, AtomArray
from numpy.random import Generator, default_rng
from scipy.spatial.distance import cdist

from openfold3.core.data.primitives.quality_control.logging_utils import (
    log_runtime_memory,
)
from openfold3.core.data.primitives.structure.interface import (
    get_query_interface_token_center_atoms,
)
from openfold3.core.data.primitives.structure.labels import (
    assign_atom_indices,
    remove_atom_indices,
)

# Sentinel ``token_budget`` value which, as the name says, stands for "do not crop".
# NOTE(review): confirm against the configs/callers that set it.
NO_CROPPING_TOKEN_BUDGET_SENTINEL: int = -1


def crop_contiguous(
    atom_array: AtomArray, token_budget: int, generator: Generator | None = None
) -> None:
    """Implements Contiguous Cropping from AF3 SI, 2.7.1.

    Uses Algorithm 1 from AF-Multimer section 7.2.1. to update the input biotite
    atom array with added 'crop_mask' annotation in-place. Note: Algorithm 1
    does not work correctly as stated in the AF-Multimer SI, so here we are using
    a fixed version.

    Chains are visited in a random order. For each chain, a crop length is sampled
    such that the remaining chains can still fill the budget. A contiguous window
    of that many tokens (in sorted token_id order) is then marked in 'crop_mask'.

    Args:
        atom_array (AtomArray):
            Biotite atom array of the first bioassembly of a PDB entry. Must carry
            'chain_id' and 'token_id' annotations.
        token_budget (int):
            Token budget i.e. total crop size.
        generator (Generator | None):
            A numpy generator set with a specific seed. If None, a generator is
            seeded from Python's global `random` module.

    Returns:
        None
    """
    if generator is None:
        seed = random.randint(0, torch.iinfo(torch.int32).max)
        generator = default_rng(seed=seed)

    # Temporary positional atom index ('_atom_idx'), used to map selections made on
    # a per-chain sub-array back onto the full atom array
    assign_atom_indices(atom_array)

    try:
        # Get chain ids and permute
        chains = np.array(sorted(set(atom_array.chain_id)))
        chains = generator.permutation(chains)

        # Create cropping mask annotation
        atom_array.set_annotation("crop_mask", np.repeat(False, len(atom_array)))

        # Number of tokens selected so far
        n_added = 0
        # Combined length of yet-to-be-cropped chains, excluding the current one
        n_remaining = len(np.unique(atom_array.token_id))

        for chain_id in chains:
            atom_array_chain = atom_array[atom_array.chain_id == chain_id]

            # Sorted unique token ids of this chain. Counting them (rather than
            # taking last - first + 1) stays correct even if the chain's token ids
            # are not a gap-free run in atom order, and matches how n_remaining is
            # counted above.
            chain_token_ids = np.unique(atom_array_chain.token_id)
            chain_length = len(chain_token_ids)
            n_remaining -= chain_length

            # Sample length of crop for current chain: never more than the budget
            # left, never so little that the remaining chains cannot fill it
            crop_size_max = min(token_budget - n_added, chain_length)
            crop_size_min = min(
                chain_length, max(0, token_budget - n_added - n_remaining)
            )
            crop_size = generator.integers(crop_size_min, crop_size_max + 1, 1).item()

            n_added += crop_size

            # Sample start of crop for current chain
            crop_start = generator.integers(0, chain_length - crop_size + 1, 1).item()

            # Slice using the sampled crop start and length for this chain
            crop_token_index_chain = chain_token_ids[
                crop_start : crop_start + crop_size
            ]
            # Map to atom indices in the full assembly
            crop_atom_index_chain = atom_array_chain[
                np.isin(atom_array_chain.token_id, crop_token_index_chain)
            ]._atom_idx

            # Edit corresponding segment in crop mask
            atom_array.crop_mask[crop_atom_index_chain] = True
    finally:
        # Always drop the temporary index annotation, even if sampling raised
        remove_atom_indices(atom_array)


def crop_spatial(
    atom_array: AtomArray,
    token_budget: int,
    generator: Generator | None = None,
    preferred_chain_or_interface: str | list[str, str] | None = None,
) -> None:
    """Implements Spatial Cropping from AF3 SI, 2.7.2.

    Uses Algorithm 2 from AF-Multimer section 7.2.2. to update the input biotite
    atom array with added 'crop_mask' annotation in-place. Note: we drop the
    index-based distance-untying step from Algorithm 2 (line 1, i * 10^-3 factor)
    because it distorts the distances and results in less convex spatial crops.

    A reference atom is sampled uniformly from the resolved token center atoms of
    the preferred chain/interface. If nothing is preferred, or the preference has
    no resolved token center atoms, it is sampled from all resolved token center
    atoms. The token_budget tokens whose center atoms are closest to it are kept.

    Args:
        atom_array (AtomArray):
            Biotite atom array of the first bioassembly of a PDB entry.
        token_budget (int):
            Total crop size.
        generator (Generator | None):
            A numpy generator set with a specific seed. If None, a generator is
            seeded from Python's global `random` module.
        preferred_chain_or_interface (str | list[str, str] | None):
            Chain ID, or list of two chain IDs, indicating the preferred chain or
            interface, respectively, from which the reference atom is sampled.
            Generated by eq. 1 in AF3 SI for the weighted PDB dataset.

    Raises:
        RuntimeError:
            If the structure has no resolved token center atoms.

    Returns:
        None
    """
    if generator is None:
        seed = random.randint(0, torch.iinfo(torch.int32).max)
        generator = default_rng(seed=seed)

    # Subset token center atoms to those in the preferred chain/interface if provided
    token_center_atoms, preferred_token_center_atoms = fetch_token_center_atoms(
        atom_array, preferred_chain_or_interface
    )

    # Sample the reference atom through a random index. Calling generator.choice on
    # the AtomArray itself would first convert it into a NumPy array of Atom objects,
    # i.e. build one Python object per candidate atom.
    # NOTE(review): this should match the index choice(..., size=1) draws for the
    # same seed; confirm if bit-for-bit reproducibility with older runs matters.
    reference_index = generator.integers(
        0, len(preferred_token_center_atoms), size=1
    )[0]
    reference_atom = preferred_token_center_atoms[int(reference_index)]

    # Find spatial crop
    find_spatial_crop(reference_atom, token_center_atoms, token_budget, atom_array)


def crop_spatial_interface(
    atom_array: AtomArray,
    token_budget: int,
    generator: Generator | None = None,
    preferred_chain_or_interface: str | list[str, str] | None = None,
) -> None:
    """Implements Spatial Interface Cropping from AF3 SI, 2.7.3.

    Uses Algorithm 2 from AF-Multimer section 7.2.2. to update the input biotite
    atom array with added 'crop_mask' annotation in-place. Note: we drop the
    index-based distance-untying step from Algorithm 2 (line 1, i * 10^-3 factor)
    because it distorts the distances and results in less convex spatial crops.

    For multi-chain structures, the reference atom is sampled from the interface
    token center atoms of the preferred chain/interface. It falls back to the
    preferred token center atoms (making this a plain spatial crop) when the
    structure has a single chain or when no interface token center atoms are found.

    Args:
        atom_array (AtomArray):
            Biotite atom array of the first bioassembly of a PDB entry.
        token_budget (int):
            Total crop size.
        generator (Generator | None):
            A numpy generator set with a specific seed. If None, a generator is
            seeded from Python's global `random` module.
        preferred_chain_or_interface (str | list[str, str] | None):
            Chain ID, or list of two chain IDs, indicating the preferred chain or
            interface, respectively, from which the reference atom is sampled.
            Generated by eq. 1 in AF3 SI for the weighted PDB dataset.

    Raises:
        RuntimeError:
            If the structure has no resolved token center atoms.

    Returns:
        None
    """
    if generator is None:
        seed = random.randint(0, torch.iinfo(torch.int32).max)
        generator = default_rng(seed=seed)

    # Subset token center atoms to those in the preferred chain/interface if provided
    token_center_atoms, preferred_token_center_atoms = fetch_token_center_atoms(
        atom_array, preferred_chain_or_interface
    )

    # By default (single chain), the interface crop is equivalent to a
    # non-interface spatial crop
    candidate_atoms = preferred_token_center_atoms
    if len(set(atom_array.chain_id)) > 1:
        # Find interface token center atoms
        preferred_interface_token_center_atoms = get_query_interface_token_center_atoms(
            preferred_token_center_atoms, token_center_atoms
        )
        # Sampling from an empty set would raise in the generator, so only narrow
        # the candidates to the interface when at least one interface atom exists
        if len(preferred_interface_token_center_atoms) > 0:
            candidate_atoms = preferred_interface_token_center_atoms

    # Sample the reference atom through a random index instead of generator.choice,
    # which would first convert the candidates into a NumPy array of Atom objects
    reference_index = generator.integers(0, len(candidate_atoms), size=1)[0]
    reference_atom = candidate_atoms[int(reference_index)]

    # Find spatial crop
    find_spatial_crop(reference_atom, token_center_atoms, token_budget, atom_array)


def fetch_token_center_atoms(
    atom_array: AtomArray,
    preferred_chain_or_interface: str | list[str, str] | None,
) -> tuple[AtomArray, AtomArray]:
    """Returns the resolved token center atoms in an atom array.

    Also returns the subset of those atoms which are in the preferred chain or
    interface.

    Args:
        atom_array (AtomArray):
            AtomArray of the input assembly. Must carry 'token_center_atom',
            'occupancy' and 'chain_id' annotations.
        preferred_chain_or_interface (str | list[str, str] | None):
            Chain ID, or list (or tuple) of two chain IDs, indicating the preferred
            chain or interface, respectively, from which reference atoms are
            selected. Generated by eq. 1 in AF3 SI for the weighted PDB dataset.

    Raises:
        RuntimeError:
            If the structure has no resolved (occupancy > 0) token center atoms.
        ValueError:
            If preferred_chain_or_interface is not None, a str, a list or a tuple.

    Returns:
        tuple[AtomArray, AtomArray]:
            Tuple of all and preferred token center atoms. Both are subset to
            resolved atoms (occupancy > 0). If the preferred chain/interface has no
            resolved token center atoms (including when it is not in the
            structure), the preferred set equals the full set.
    """
    token_center_atoms = atom_array[atom_array.token_center_atom]

    # Subset to resolved token center atoms
    token_center_atoms = token_center_atoms[token_center_atoms.occupancy > 0]

    if len(token_center_atoms) == 0:
        raise RuntimeError(
            "Cannot crop a structure with no resolved token center atoms."
        )

    if preferred_chain_or_interface is None:
        preferred_token_center_atoms = token_center_atoms
    # Single preferred chain
    elif isinstance(preferred_chain_or_interface, str):
        preferred_token_center_atoms = token_center_atoms[
            token_center_atoms.chain_id == preferred_chain_or_interface
        ]
    # Preferred interface given as a pair of chain IDs
    elif isinstance(preferred_chain_or_interface, (list, tuple)):
        preferred_token_center_atoms = token_center_atoms[
            np.isin(token_center_atoms.chain_id, list(preferred_chain_or_interface))
        ]
    else:
        raise ValueError(
            "Invalid preferred_chain_or_interface: "
            f"{preferred_chain_or_interface}, has to be str or 2-list."
        )

    # If the preferred chain/interface has no resolved atoms, use all resolved token
    # center atoms.
    # Note: this will also be the case if a chain or interface is provided that is
    # not in the structure.
    if len(preferred_token_center_atoms) == 0:
        preferred_token_center_atoms = token_center_atoms

    return token_center_atoms, preferred_token_center_atoms


def find_spatial_crop(
    reference_atom: Atom,
    token_center_atoms: AtomArray,
    token_budget: int,
    atom_array: AtomArray,
) -> None:
    """Selects the token_budget tokens whose center atoms are closest to an atom.

    Sets the 'crop_mask' annotation of atom_array in-place. It is True for every
    atom whose token_id belongs to one of the selected token center atoms and False
    otherwise. If token_budget is at least the number of token center atoms, all of
    their tokens are selected.

    Args:
        reference_atom (Atom):
            The sampled reference atom around which the spatial crop is created.
        token_center_atoms (AtomArray):
            The set of token center atoms to crop from.
        token_budget (int):
            Crop size.
        atom_array (AtomArray):
            Input atom array of the bioassembly.

    Returns:
        None
    """
    # Euclidean distance from the reference atom to every token center atom
    distances_to_reference_atom = cdist(
        np.reshape(reference_atom.coord, (1, -1)), token_center_atoms.coord
    )[0, :]

    # Get token_budget nearest token center atoms. A stable sort breaks exact
    # distance ties by position in token_center_atoms. NumPy's default sort kind is
    # not stable, so tie order would otherwise be implementation-dependent.
    nearest_token_center_atom_ids = np.argsort(
        distances_to_reference_atom, kind="stable"
    )[:token_budget]

    # Get all atoms for nearest token center atoms
    atom_array.set_annotation(
        "crop_mask",
        np.isin(
            atom_array.token_id,
            token_center_atoms.token_id[nearest_token_center_atom_ids],
        ),
    )


# Keyword arguments accepted by both spatial cropping strategies
_SPATIAL_CROP_ARGNAMES = ("atom_array", "token_budget", "preferred_chain_or_interface")

# Cropping strategy name -> (crop function, names of the keyword arguments to pass
# it). Insertion order fixes the order in which strategies are weighted when one is
# sampled.
CROP_REGISTRY = {
    "contiguous": (crop_contiguous, ("atom_array", "token_budget")),
    "spatial": (crop_spatial, _SPATIAL_CROP_ARGNAMES),
    "spatial_interface": (crop_spatial_interface, _SPATIAL_CROP_ARGNAMES),
}


def sample_crop_strategy(crop_weights: dict[str, float]) -> str:
    """Draws one cropping strategy name at random, weighted by crop_weights.

    Args:
        crop_weights (dict[str, float]):
            Mapping from every strategy name in CROP_REGISTRY to its (not
            necessarily normalized) sampling weight. A missing strategy raises
            KeyError.

    Returns:
        str:
            Name of the sampled cropping strategy.
    """
    strategy_names = list(CROP_REGISTRY.keys())
    weights = torch.tensor(
        [crop_weights[name] for name in strategy_names], dtype=torch.float
    )
    sampled_index = torch.multinomial(weights, num_samples=1, replacement=False).item()
    return strategy_names[sampled_index]


@log_runtime_memory(runtime_dict_key="runtime-target-structure-proc-crop")
def sample_crop_and_set_mask(
    atom_array: AtomArray,
    apply_crop: bool,
    crop_config: dict,
    preferred_chain_or_interface: str | list[str, str] | None,
) -> str:
    """Samples cropping strategy and sets the crop mask.

    Running this function on an AtomArray will add the 'crop_mask' annotation in-place,
    which is True for atoms inside the crop and False for atoms outside the crop.

    Args:
        atom_array (AtomArray):
            AtomArray of the input assembly to crop.
        apply_crop (bool):
            Whether to apply cropping.
        crop_config (dict):
            Crop configuration dictionary, containing the following keys:
             - token_budget (int): Number of tokens to sample. The value
               NO_CROPPING_TOKEN_BUDGET_SENTINEL disables cropping.
             - crop_weights (dict): Weights of different crop strategies.
            No key is read when apply_crop is False.
        preferred_chain_or_interface (str | list[str, str] | None):
            Chain ID, or list of two chain IDs, indicating the preferred chain or
            interface, respectively, from which reference atoms are selected.
            Generated by eq. 1 in AF3 SI for the weighted PDB dataset. Only passed
            to the spatial cropping strategies.

    Returns:
        str:
            Name of the sampled cropping strategy. Returns 'whole' if cropping is
            disabled (apply_crop is False or the token budget is the no-cropping
            sentinel) or the whole assembly fits into the token budget. Should be
            one of ['contiguous', 'spatial', 'spatial_interface', 'whole'].
    """
    # The sentinel budget must be caught here. As a negative budget it would
    # otherwise be handed to the crop functions, which do not check for it.
    cropping_disabled = (not apply_crop) or (
        crop_config["token_budget"] == NO_CROPPING_TOKEN_BUDGET_SENTINEL
    )

    # Take whole assembly if shouldn't crop or it fits in the budget
    if cropping_disabled or (
        len(np.unique(atom_array.token_id)) <= crop_config["token_budget"]
    ):
        atom_array.set_annotation("crop_mask", np.repeat(True, len(atom_array)))
        return "whole"

    # Otherwise sample a strategy and pass it only the arguments it accepts
    crop_strategy = sample_crop_strategy(crop_config["crop_weights"])
    crop_function, crop_function_argnames = CROP_REGISTRY[crop_strategy]
    crop_input = {
        "atom_array": atom_array,
        "token_budget": crop_config["token_budget"],
        "preferred_chain_or_interface": preferred_chain_or_interface,
    }
    crop_function(
        **{k: v for k, v in crop_input.items() if k in crop_function_argnames}
    )
    return crop_strategy
